{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning a communication channel\n",
    "\n",
    "Normally, if you have a function you would like to compute\n",
    "across a connection, you would specify it with `function=my_func`\n",
    "in the `Connection` constructor.\n",
    "However, it is also possible to use error-driven learning\n",
    "to learn to compute a function online."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Create the model without learning\n",
    "\n",
    "We'll start by creating a connection between two populations\n",
    "that initially computes a function unrelated to its input:\n",
    "the `function` we pass in simply returns random values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import nengo\n",
    "from nengo.processes import WhiteSignal\n",
    "from nengo.solvers import LstsqL2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nengo.Network()\n",
    "with model:\n",
    "    # Band-limited white noise (5 Hz cutoff) as a 2D input signal\n",
    "    inp = nengo.Node(WhiteSignal(60, high=5), size_out=2)\n",
    "    pre = nengo.Ensemble(60, dimensions=2)\n",
    "    nengo.Connection(inp, pre)\n",
    "    post = nengo.Ensemble(60, dimensions=2)\n",
    "    # The initial decoders are solved against random target values\n",
    "    conn = nengo.Connection(pre, post, function=lambda x: np.random.random(2))\n",
    "    # Probe the raw input and the filtered decoded values\n",
    "    inp_p = nengo.Probe(inp)\n",
    "    pre_p = nengo.Probe(pre, synapse=0.01)\n",
    "    post_p = nengo.Probe(post, synapse=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we run this model as is, we can see that the connection\n",
    "from `pre` to `post` doesn't compute anything useful:\n",
    "the value decoded from `post` does not track `pre`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(10.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 8))\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.plot(sim.trange(), sim.data[inp_p].T[0], c='k', label='Input')\n",
    "plt.plot(sim.trange(), sim.data[pre_p].T[0], c='b', label='Pre')\n",
    "plt.plot(sim.trange(), sim.data[post_p].T[0], c='r', label='Post')\n",
    "plt.ylabel(\"Dimension 1\")\n",
    "plt.legend(loc='best')\n",
    "plt.subplot(2, 1, 2)\n",
    "plt.plot(sim.trange(), sim.data[inp_p].T[1], c='k', label='Input')\n",
    "plt.plot(sim.trange(), sim.data[pre_p].T[1], c='b', label='Pre')\n",
    "plt.plot(sim.trange(), sim.data[post_p].T[1], c='r', label='Post')\n",
    "plt.ylabel(\"Dimension 2\")\n",
    "plt.legend(loc='best');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Add in learning\n",
    "\n",
    "If we can generate an error signal, then we can minimize\n",
    "that error signal using the `nengo.PES` learning rule.\n",
    "Since it's a communication channel, we know the value that we want,\n",
    "so we can compute the error with another ensemble."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with model:\n",
    "    error = nengo.Ensemble(60, dimensions=2)\n",
    "    error_p = nengo.Probe(error, synapse=0.03)\n",
    "\n",
    "    # Error = actual - target = post - pre\n",
    "    nengo.Connection(post, error)\n",
    "    nengo.Connection(pre, error, transform=-1)\n",
    "\n",
    "    # Add the learning rule to the connection\n",
    "    conn.learning_rule_type = nengo.PES()\n",
    "\n",
    "    # Connect the error into the learning rule\n",
    "    nengo.Connection(error, conn.learning_rule)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can see the `post` population gradually learn to compute\n",
    "the communication channel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(10.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 12))\n",
    "plt.subplot(3, 1, 1)\n",
    "plt.plot(sim.trange(), sim.data[inp_p].T[0], c='k', label='Input')\n",
    "plt.plot(sim.trange(), sim.data[pre_p].T[0], c='b', label='Pre')\n",
    "plt.plot(sim.trange(), sim.data[post_p].T[0], c='r', label='Post')\n",
    "plt.ylabel(\"Dimension 1\")\n",
    "plt.legend(loc='best')\n",
    "plt.subplot(3, 1, 2)\n",
    "plt.plot(sim.trange(), sim.data[inp_p].T[1], c='k', label='Input')\n",
    "plt.plot(sim.trange(), sim.data[pre_p].T[1], c='b', label='Pre')\n",
    "plt.plot(sim.trange(), sim.data[post_p].T[1], c='r', label='Post')\n",
    "plt.ylabel(\"Dimension 2\")\n",
    "plt.legend(loc='best')\n",
    "plt.subplot(3, 1, 3)\n",
    "plt.plot(sim.trange(), sim.data[error_p], c='b')\n",
    "plt.ylim(-1, 1)\n",
    "plt.legend((\"Error[0]\", \"Error[1]\"), loc='best');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Does it generalize?\n",
    "\n",
    "As long as the learning rule is active,\n",
    "the error will continue to be minimized.\n",
    "But have we actually generalized\n",
    "to be able to compute the communication channel\n",
    "without this error signal?\n",
    "Let's inhibit the `error` population after 10 seconds."
   ]
  },
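  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inhibit(t):\n",
    "    # Output 2.0 once t exceeds 10 seconds, and 0.0 before that\n",
    "    return 2.0 if t > 10.0 else 0.0\n",
    "\n",
    "\n",
    "with model:\n",
    "    inhib = nengo.Node(inhibit)\n",
    "    # Project the signal onto every `error` neuron with weight -1,\n",
    "    # silencing the population so the decoded error goes to zero\n",
    "    nengo.Connection(inhib, error.neurons, transform=[[-1]] * error.n_neurons)"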
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(16.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 12))\n",
    "plt.subplot(3, 1, 1)\n",
    "plt.plot(sim.trange(), sim.data[inp_p].T[0], c='k', label='Input')\n",
    "plt.plot(sim.trange(), sim.data[pre_p].T[0], c='b', label='Pre')\n",
    "plt.plot(sim.trange(), sim.data[post_p].T[0], c='r', label='Post')\n",
    "plt.ylabel(\"Dimension 1\")\n",
    "plt.legend(loc='best')\n",
    "plt.subplot(3, 1, 2)\n",
    "plt.plot(sim.trange(), sim.data[inp_p].T[1], c='k', label='Input')\n",
    "plt.plot(sim.trange(), sim.data[pre_p].T[1], c='b', label='Pre')\n",
    "plt.plot(sim.trange(), sim.data[post_p].T[1], c='r', label='Post')\n",
    "plt.ylabel(\"Dimension 2\")\n",
    "plt.legend(loc='best')\n",
    "plt.subplot(3, 1, 3)\n",
    "plt.plot(sim.trange(), sim.data[error_p], c='b')\n",
    "plt.ylim(-1, 1)\n",
    "plt.legend((\"Error[0]\", \"Error[1]\"), loc='best');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How does this work?\n",
    "\n",
    "The `nengo.PES` learning rule minimizes the same error online\n",
    "as the decoder solvers minimize with offline optimization.\n",
    "\n",
    "Let $\\mathbf{E}$ be an error signal.\n",
    "In the communication channel case, the error signal\n",
    "$\\mathbf{E} = \\mathbf{\\hat{x}} - \\mathbf{x}$;\n",
    "in other words, it is the difference between\n",
    "the decoded estimate of `post`, $\\mathbf{\\hat{x}}$,\n",
    "and the decoded estimate of `pre`, $\\mathbf{x}$.\n",
    "\n",
    "The PES learning rule on decoders is\n",
    "\n",
    "$$\\Delta \\mathbf{d_i} = -\\frac{\\kappa}{n} \\mathbf{E} a_i$$\n",
    "\n",
    "where $\\mathbf{d_i}$ are the decoders being learned,\n",
    "$\\kappa$ is a scalar learning rate, $n$ is the number of neurons\n",
    "in the `pre` population,\n",
    "and $a_i$ is the filtered activity of the `pre` population.\n",
    "\n",
    "However, many synaptic plasticity experiments\n",
    "result in learning rules that explain how\n",
    "individual connection weights change.\n",
    "We can multiply both sides of the equation\n",
    "by the encoders of the `post` population,\n",
    "$\\mathbf{e_j}$, and the gain of the `post`\n",
    "population $\\alpha_j$, as we do in\n",
    "Principle 2 of the NEF.\n",
    "This results in the learning rule\n",
    "\n",
    "$$\n",
    "\\Delta \\omega_{ij} = \\Delta \\mathbf{d_i} \\cdot \\mathbf{e_j} \\alpha_j\n",
    "  = -\\frac{\\kappa}{n} \\alpha_j \\mathbf{e_j} \\cdot \\mathbf{E} a_i\n",
    "$$\n",
    "\n",
    "where $\\omega_{ij}$ is the connection\n",
    "between `pre` neuron $i$ and `post` neuron $j$.\n",
    "\n",
    "The weight-based version of PES can be easily combined with\n",
    "learning rules that describe synaptic plasticity experiments.\n",
    "In Nengo, the `Connection.learning_rule_type` parameter accepts\n",
    "a list of learning rules.\n",
    "See [Bekolay et al., 2013](\n",
    "http://compneuro.uwaterloo.ca/publications/bekolay2013.html)\n",
    "for details on what happens when the PES learning rule is\n",
    "combined with an unsupervised learning rule."
   ]
  },
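  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the decoder update above,\n",
    "the next cell applies it in plain NumPy, without Nengo.\n",
    "Everything in it is an assumption made for the sketch:\n",
    "the function name `pes_decoder_sketch`, the rectified-linear\n",
    "stand-in for the filtered activities $a_i$,\n",
    "the random unit-length encoders and biases,\n",
    "and the learning rate `kappa=0.5`.\n",
    "It is not how `nengo.PES` is implemented internally;\n",
    "it only shows that repeatedly applying\n",
    "$\\Delta \\mathbf{d_i} = -\\frac{\\kappa}{n} \\mathbf{E} a_i$\n",
    "should shrink the error of a communication channel over time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def pes_decoder_sketch(n_neurons=50, dims=2, kappa=0.5, n_steps=2000, seed=0):\n",
    "    # Hypothetical helper: PES decoder updates with non-spiking activities\n",
    "    rng = np.random.RandomState(seed)\n",
    "\n",
    "    # Assumed tuning: random unit-length encoders and random biases\n",
    "    enc = rng.randn(n_neurons, dims)\n",
    "    enc /= np.linalg.norm(enc, axis=1, keepdims=True)\n",
    "    bias = rng.uniform(-1, 1, size=n_neurons)\n",
    "\n",
    "    decoders = np.zeros((dims, n_neurons))  # column i holds d_i\n",
    "    error_norms = np.zeros(n_steps)\n",
    "    for step in range(n_steps):\n",
    "        x = rng.uniform(-1, 1, size=dims)  # target value\n",
    "        a = np.maximum(0, enc @ x + bias)  # stand-in for a_i\n",
    "        E = decoders @ a - x  # error = actual - target\n",
    "        decoders -= (kappa / n_neurons) * np.outer(E, a)  # PES update\n",
    "        error_norms[step] = np.linalg.norm(E)\n",
    "    return error_norms\n",
    "\n",
    "\n",
    "sketch_errors = pes_decoder_sketch()\n",
    "plt.figure(figsize=(12, 4))\n",
    "plt.plot(sketch_errors)\n",
    "plt.xlabel(\"Update step\")\n",
    "plt.ylabel(\"Error norm\");"
   ]
  },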
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How do the decoders / weights change?\n",
    "\n",
    "The equations above describe\n",
    "how the decoders and connection weights change\n",
    "as a result of the PES rule.\n",
    "But are there any general principles\n",
    "that we can state about how the rule\n",
    "modifies decoders and connection weights?\n",
    "Determining this requires analyzing\n",
    "the decoders and connection weights\n",
    "as they change over the course of a simulation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with model:\n",
    "    weights_p = nengo.Probe(conn, 'weights', synapse=0.01, sample_every=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(4.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 12))\n",
    "plt.subplot(3, 1, 1)\n",
    "plt.plot(sim.trange(), sim.data[inp_p].T[0], c='k', label='Input')\n",
    "plt.plot(sim.trange(), sim.data[pre_p].T[0], c='b', label='Pre')\n",
    "plt.plot(sim.trange(), sim.data[post_p].T[0], c='r', label='Post')\n",
    "plt.ylabel(\"Dimension 1\")\n",
    "plt.legend(loc='best')\n",
    "plt.subplot(3, 1, 2)\n",
    "plt.plot(sim.trange(), sim.data[inp_p].T[1], c='k', label='Input')\n",
    "plt.plot(sim.trange(), sim.data[pre_p].T[1], c='b', label='Pre')\n",
    "plt.plot(sim.trange(), sim.data[post_p].T[1], c='r', label='Post')\n",
    "plt.ylabel(\"Dimension 2\")\n",
    "plt.legend(loc='best')\n",
    "plt.subplot(3, 1, 3)\n",
    "plt.plot(sim.trange(sample_every=0.01), sim.data[weights_p][..., 10])\n",
    "plt.ylabel(\"Decoding weight\")\n",
    "plt.legend((\"Decoder 10[0]\", \"Decoder 10[1]\"), loc='best');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with model:\n",
    "    # Change the connection to use connection weights instead of decoders\n",
    "    conn.solver = LstsqL2(weights=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(4.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 12))\n",
    "plt.subplot(3, 1, 1)\n",
    "plt.plot(sim.trange(), sim.data[inp_p].T[0], c='k', label='Input')\n",
    "plt.plot(sim.trange(), sim.data[pre_p].T[0], c='b', label='Pre')\n",
    "plt.plot(sim.trange(), sim.data[post_p].T[0], c='r', label='Post')\n",
    "plt.ylabel(\"Dimension 1\")\n",
    "plt.legend(loc='best')\n",
    "plt.subplot(3, 1, 2)\n",
    "plt.plot(sim.trange(), sim.data[inp_p].T[1], c='k', label='Input')\n",
    "plt.plot(sim.trange(), sim.data[pre_p].T[1], c='b', label='Pre')\n",
    "plt.plot(sim.trange(), sim.data[post_p].T[1], c='r', label='Post')\n",
    "plt.ylabel(\"Dimension 2\")\n",
    "plt.legend(loc='best')\n",
    "plt.subplot(3, 1, 3)\n",
    "plt.plot(sim.trange(sample_every=0.01), sim.data[weights_p][..., 10])\n",
    "plt.ylabel(\"Connection weight\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For some general principles governing how the decoders change,\n",
    "[Voelker, 2015](http://compneuro.uwaterloo.ca/publications/voelker2015.html)\n",
    "and [Voelker & Eliasmith, 2017](\n",
    "http://compneuro.uwaterloo.ca/publications/voelker2017c.html)\n",
    "give detailed analyses of the rule's dynamics.\n",
    "It's also interesting to observe qualitative trends;\n",
    "often a few strong connection weights will\n",
    "dominate the others,\n",
    "while decoding weights tend to\n",
    "change or not change together."
   ]
  },
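  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One rough way to check the first of these observations\n",
    "is to look at the distribution of the final connection weights.\n",
    "The cell below is only a sketch: it assumes that the most recent\n",
    "simulation was the one using `LstsqL2(weights=True)`,\n",
    "so that `sim.data[weights_p]` holds full weight matrices.\n",
    "The name `final_weights` and the choice of 50 histogram bins\n",
    "are arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last sampled weight matrix from the most recent simulation\n",
    "final_weights = sim.data[weights_p][-1]\n",
    "\n",
    "# A long-tailed histogram would suggest a few dominant weights\n",
    "plt.figure(figsize=(12, 4))\n",
    "plt.hist(final_weights.ravel(), bins=50)\n",
    "plt.xlabel(\"Connection weight\")\n",
    "plt.ylabel(\"Count\");"
   ]
  },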
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What is `pre_synapse`?\n",
    "\n",
    "By default, the `PES` object sets\n",
    "`pre_synapse=Lowpass(tau=0.005)`.\n",
    "This is a lowpass filter with time-constant $\\tau = 5\\,\\text{ms}$\n",
    "that is applied to the `pre` activities $a_i$\n",
    "before computing each update $\\Delta \\mathbf{d_i}$.\n",
    "\n",
    "In general, longer time-constants\n",
    "smooth over the spiking activity to produce more constant updates,\n",
    "while shorter time-constants adapt more quickly\n",
    "to rapidly changing inputs.\n",
    "The right trade-off depends on\n",
    "the particular demands of the model.\n",
    "\n",
    "This `Synapse` object can also be\n",
    "any other linear filter (as are used in the `Connection` object);\n",
    "for instance, `pre_synapse=Alpha(tau=0.005)`\n",
    "applies an alpha filter to the presynaptic activity.\n",
    "This will have the effect of delaying the bulk of the activities\n",
    "by a rise-time of $\\tau$ before applying the update.\n",
    "This may be useful for situations\n",
    "where the error signal is delayed by the same amount,\n",
    "since the error signal should be synchronized\n",
    "with the same activities that produced said error."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
